# damcts.py
# Dimension-Adaptive MCTS (DAMCTS) with state-dict snapshots (no full-env pickling)
import faulthandler; faulthandler.enable()

import math
import copy
import gym
import random
import numpy as np
import statistics

# Local modules / envs
import improved_walker2d  # ensure registration
import improved_ant
import improved_humanoid
from SnapshotENV import SnapshotEnv

# --------------------------
# Global config
# --------------------------
discount = 0.99            # discount factor for planning backups, rollouts and test-time returns
TEST_ITERATIONS = 150      # maximum environment steps per evaluation episode
MAX_DAMCTS_DEPTH = 100     # maximum steps of each random rollout

# Polynomial exploration bonus: BONUS_C * N_parent**PARENT_EXP / N_child**CHILD_EXP
BONUS_C = 30.0
PARENT_EXP = 0.25
CHILD_EXP = 0.5

# Exponent p of the power-mean value backup
POWER = 2.0

# Dimension-adaptive discretization
EPSILON_1 = 0.5   # net resolution at the coarsest level (k = 1)
BETA = 1.0
L_HOLDER = 1.0    # Hölder constant used inside the epsilon bonus

# Planning-time reward shaping: shaped = max(0.01, (r + offset) * REWARD_SCALING).
# Evaluation uses the true reward; callers fall back to an offset of 10.0 for unlisted envs.
REWARD_SCALING = 1.0
REWARD_OFFSETS = {
    "ImprovedWalker2d-v0": 20.0,
    "ImprovedAnt-v0": 20.0,
    "ImprovedHumanoid-v0": 25.0,
}

# Noise settings forwarded to gym.make(...) as keyword arguments.
# All environments currently share the same values; each gets its own dict.
ENV_NOISE_CONFIG = {
    env_id: {
        "action_noise_scale": 0.03,
        "dynamics_noise_scale": 0.02,
        "obs_noise_scale": 0.01,
    }
    for env_id in ("ImprovedWalker2d-v0", "ImprovedAnt-v0", "ImprovedHumanoid-v0")
}

# --------------------------
# Epsilon-net construction
# --------------------------
def build_epsilon_net(env_name, action_dim, epsilon, lo=-1.0, hi=1.0):
    """
    Return a discrete set of candidate actions (a list of float32 vectors) in [lo, hi]^d.

    - action_dim <= 4: a regular grid with ceil((1/eps) ** (1/d)) points per axis,
      clipped to [2, 25], so the grid has roughly 1/eps points in total.
      NOTE(review): a true sup-norm epsilon-cover would need ~(1/eps) ** d points.
    - action_dim > 4: (1/eps) ** min(d, 2) uniform samples from the global NumPy RNG,
      clipped to [100, 1500] for d <= 8 and [100, 2000] above that. A new random set is
      drawn on every call.

    epsilon is floored at 1e-3. env_name is accepted for interface compatibility but unused.
    """
    inv_eps = 1.0 / max(epsilon, 1e-3)

    if action_dim > 4:
        # Randomized net: sample size grows like inv_eps**2, with practical caps.
        upper = 2000 if action_dim > 8 else 1500
        count = int(np.clip(inv_eps ** min(action_dim, 2), 100, upper))
        cloud = np.random.uniform(lo, hi, size=(count, action_dim)).astype(np.float32)
        return list(cloud)

    # Regular grid: the same evenly spaced axis repeated in every dimension.
    n_per_axis = int(np.clip(int(np.ceil(inv_eps ** (1.0 / action_dim))), 2, 25))
    axis = np.linspace(lo, hi, n_per_axis, dtype=np.float32)
    grids = np.meshgrid(*([axis] * action_dim), indexing="ij")
    flat = np.stack([g.ravel() for g in grids], axis=-1)
    return [row.astype(np.float32, copy=False) for row in flat]


def get_env_action_space(env_name):
    """Return ``(action_dim, low, high)`` for a known environment id.

    Unknown ids fall back to a one-dimensional action in [-1, 1].
    """
    # NOTE(review): stock MuJoCo Humanoid bounds its 17 actions to [-0.4, 0.4];
    # confirm that ImprovedHumanoid-v0 really accepts [-1, 1].
    known_spaces = {
        "ImprovedWalker2d-v0": (6, -1.0, 1.0),
        "ImprovedAnt-v0": (8, -1.0, 1.0),
        "ImprovedHumanoid-v0": (17, -1.0, 1.0),
    }
    return known_spaces.get(env_name, (1, -1.0, 1.0))


# --------------------------
# DAMCTS Nodes
# --------------------------
class DAMCTSNode:
    """Node of a Dimension-Adaptive MCTS tree over a box-bounded continuous action space.

    A non-root node is created by simulating ``action`` from its parent's snapshot via
    ``env.get_result``. It caches the resulting snapshot, the observation, the shaped
    immediate reward and the terminal flag. Children live in ``children``, keyed by the
    action rounded to 6 decimals (see ``_act_key``). Visit statistics support a plain
    mean and a power mean. ``value_sum_power`` is the sum of returns raised to ``POWER``.
    """

    __slots__ = (
        "parent",
        "action",
        "children",
        "visit_count",
        "value_sum",
        "value_sum_power",
        "snapshot",
        "obs",
        "immediate_reward",
        "is_done",
        "eps_net_func",
        "env_name",
        "_action_dim",
        "lo",
        "hi",
    )

    def __init__(self, parent, action, eps_net_func, env: SnapshotEnv, env_name, action_dim, lo, hi):
        """Create a node. For non-root nodes, simulate ``action`` from ``parent``.

        Args:
            parent: parent node, or None for a root. For a root, snapshot/obs are left as None.
            action: action taken from ``parent`` to reach this node (stored as float32), or None.
            eps_net_func: callable ``epsilon -> list of actions`` that supplies the
                candidate actions for expansion.
            env: SnapshotEnv used to simulate the transition. Ignored when ``parent`` is None.
            env_name: environment id. Selects the reward offset from REWARD_OFFSETS.
            action_dim, lo, hi: action dimensionality and per-coordinate bounds.
        """
        self.parent = parent
        self.action = None if action is None else np.asarray(action, dtype=np.float32)
        self.children = {}  # rounded action tuple (see _act_key) -> child node
        self.visit_count = 0
        self.value_sum = 0.0
        self.value_sum_power = 0.0

        self.env_name = env_name
        self._action_dim = int(action_dim)
        self.lo, self.hi = float(lo), float(hi)

        if parent is None:
            # Root: no transition to simulate.
            self.snapshot = None
            self.obs = None
            self.immediate_reward = 0.0
            self.is_done = False
        else:
            # Simulate parent --action--> this node.
            # get_result returns (snapshot, obs, reward, done, info).
            snapshot, obs, reward, done, _info = env.get_result(parent.snapshot, self.action)
            self.snapshot = snapshot
            self.obs = obs
            # Planning-time shaping: shift/scale, then floor at 0.01 so returns stay
            # positive for the power-mean backup.
            r_offset = REWARD_OFFSETS.get(env_name, 10.0)
            self.immediate_reward = max(0.01, (reward + r_offset) * REWARD_SCALING)
            self.is_done = bool(done)

        self.eps_net_func = eps_net_func

    @staticmethod
    def _act_key(a: np.ndarray, ndigits=6):
        """Hashable key for an action: its float32 coordinates rounded to ``ndigits``."""
        return tuple(np.round(a.astype(np.float32), decimals=ndigits))

    def is_root(self):
        """True when the node has no parent."""
        return self.parent is None

    def get_action_dim(self):
        """Dimensionality of the action space."""
        return self._action_dim

    def get_mean_value(self):
        """Arithmetic mean of backed-up returns (0.0 if unvisited)."""
        return self.value_sum / self.visit_count if self.visit_count > 0 else 0.0

    def get_power_mean_value(self):
        """Power mean (exponent POWER) of backed-up returns (0.0 if unvisited)."""
        if self.visit_count == 0:
            return 0.0
        return (self.value_sum_power / self.visit_count) ** (1.0 / POWER)

    # ---- Dimension-adaptive epsilon level ----
    def epsilon_level(self):
        """Return ``(k, eps_k, size)`` for the current visit count.

        ``eps_k = EPSILON_1 * 2 ** (-(k - 1) / (d + 2 * BETA))`` and ``size`` approximates
        the number of actions in the corresponding net, using the same per-axis / sample
        count rules and caps as ``build_epsilon_net``. The result is the smallest level ``k``
        with ``visit_count <= size ** 2``. If the net size has already reached its cap,
        that level is returned instead. Without this check the loop never terminates for
        visit counts above ``cap ** 2``, e.g. more than 625 visits when d == 1.
        """
        n = max(self.visit_count, 1)
        d = self.get_action_dim()
        if d <= 4:
            max_size = 25 ** d
        else:
            max_size = 1500 if d <= 8 else 2000  # same caps as build_epsilon_net
        k = 1
        while True:
            eps_k = EPSILON_1 * (2.0 ** (-(k - 1) / (d + 2.0 * BETA)))
            if d <= 4:
                per_dim = max(2, int(np.ceil((1.0 / max(eps_k, 1e-6)) ** (1.0 / d))))
                size = min(25, per_dim) ** d
            else:
                # approximate size of the randomized net
                size = int(np.clip((1.0 / max(eps_k, 1e-6)) ** 2, 100, max_size))
            if n <= size * size or size >= max_size:
                return k, eps_k, size
            k += 1

    # ---- UCB with polynomial bonus + discretization bonus ----
    def ucb_score(self, child_node, parent_visits):
        """Selection score of ``child_node`` under this node.

        score = child's power-mean value
              + L_HOLDER * eps_k ** BETA                                  (discretization term)
              + BONUS_C * parent_visits**PARENT_EXP / child_visits**CHILD_EXP  (exploration)

        A missing or unvisited child, or ``parent_visits == 0``, scores +inf. The
        discretization term depends only on this node, so it shifts all siblings equally.
        """
        if child_node is None or child_node.visit_count == 0:
            return float("inf")

        pm = child_node.get_power_mean_value()
        _, eps_k, _ = self.epsilon_level()
        eps_bonus = L_HOLDER * (eps_k ** BETA)

        if parent_visits > 0:
            cb = BONUS_C * (parent_visits ** PARENT_EXP) / (child_node.visit_count ** CHILD_EXP)
        else:
            cb = float("inf")

        return pm + eps_bonus + cb

    def get_child(self, action):
        """Child reached by ``action`` (matched after rounding), or None."""
        return self.children.get(self._act_key(action), None)

    # ---- Tree policy: selection to a leaf (or expandable node) ----
    def tree_policy(self, env: SnapshotEnv):
        """Descend from this node and return the node to evaluate.

        At each node, expand the first action of the current epsilon-net that has no
        child yet, and return that new child. If every net action already has a child,
        move to the child with the highest ``ucb_score`` (ties keep the first). Stops at
        terminal nodes, or at a node with no children to select from.
        """
        node = self
        while True:
            if node.is_done:
                return node

            _, eps_k, _ = node.epsilon_level()
            for a in node.eps_net_func(eps_k):
                key = node._act_key(a)
                if key not in node.children:
                    child = DAMCTSNode(
                        parent=node,
                        action=a,
                        eps_net_func=node.eps_net_func,
                        env=env,
                        env_name=node.env_name,
                        action_dim=node._action_dim,
                        lo=node.lo,
                        hi=node.hi,
                    )
                    node.children[key] = child
                    return child  # one expansion per call

            # Fully expanded at this level: select by UCB and keep descending.
            best_child = None
            best_score = -float("inf")
            for child in node.children.values():
                score = node.ucb_score(child, node.visit_count)
                if score > best_score:
                    best_score = score
                    best_child = child
            if best_child is None:
                return node
            node = best_child

    # ---- Default rollout: random actions for max_depth ----
    def rollout(self, env: SnapshotEnv, max_depth=MAX_DAMCTS_DEPTH):
        """Estimate the value of this node's state with a uniform-random rollout.

        Returns ``sum_t discount**t * shaped_reward_t`` over at most ``max_depth`` steps,
        starting from this node's snapshot. The first step is undiscounted. Returns 0.0
        for a terminal node.
        """
        if self.is_done:
            return 0.0

        env.load_snapshot(self.snapshot)

        total = 0.0
        disc = 1.0
        r_offset = REWARD_OFFSETS.get(self.env_name, 10.0)
        a_dim = self.get_action_dim()

        for _ in range(max_depth):
            a = np.random.uniform(self.lo, self.hi, size=(a_dim,)).astype(np.float32)
            _obs, r, done, _info = env.step(a)
            shaped = max(0.01, (r + r_offset) * REWARD_SCALING)
            total += shaped * disc
            disc *= discount
            if done:
                break
        return total

    # ---- Backprop with power-mean backup ----
    def backpropagate(self, rollout_return):
        """Back a value estimate up from this node to the root.

        Args:
            rollout_return: estimated shaped return from this node's state onward,
                e.g. ``rollout()``'s result, or 0.0 for a terminal node.

        Each node on the path records the return of the edge leading into it,
        ``immediate_reward + discount * value_below``. That same quantity becomes the
        value passed to its parent, so every intermediate reward is counted. The
        previous version passed only the discounted rollout upward and dropped them.
        The walk is iterative, so deep reused trees cannot exhaust Python's recursion limit.
        """
        node = self
        value_below = rollout_return
        while node is not None:
            total_return = node.immediate_reward + discount * value_below
            node.visit_count += 1
            node.value_sum += total_return
            node.value_sum_power += total_return ** POWER
            value_below = total_return
            node = node.parent


class DAMCTSRoot(DAMCTSNode):
    """Root of a DAMCTS tree, anchored at a caller-supplied environment state.

    Unlike ordinary nodes, a root is not produced by simulating a transition. The caller
    passes the snapshot and observation directly, and the root starts with zero immediate
    reward in a non-terminal state.
    """

    def __init__(self, snapshot, obs, eps_net_func, env_name, action_dim, lo, hi):
        # parent=None makes the base constructor skip simulation, so no env is needed.
        super().__init__(None, None, eps_net_func, None, env_name, action_dim, lo, hi)
        # Replace the base class's empty root state with the supplied one.
        self.snapshot, self.obs = snapshot, obs
        self.immediate_reward, self.is_done = 0.0, False


# --------------------------
# DAMCTS planning loop
# --------------------------
def plan_damcts(root: DAMCTSRoot, env: SnapshotEnv, n_iter: int):
    """Run ``n_iter`` DAMCTS iterations from ``root``, growing the tree in place.

    Each iteration selects or expands a leaf with ``tree_policy``, estimates its value with
    a random rollout (0.0 for a terminal leaf), and backs that value up to the root.
    """
    for _iteration in range(n_iter):
        leaf = root.tree_policy(env)
        leaf_value = 0.0 if leaf.is_done else leaf.rollout(env, max_depth=MAX_DAMCTS_DEPTH)
        leaf.backpropagate(leaf_value)


# --------------------------
# Experiment runner (single env)
# --------------------------
def _make_cached_eps_net_func(envname, action_dim, lo, hi):
    """Return an ``eps_net_func`` that builds each epsilon-net once and reuses it.

    For action_dim > 4, ``build_epsilon_net`` draws fresh random samples on every call.
    ``tree_policy`` asks for the net on every descent, so an uncached function gives each
    node a brand-new action set every time. Every iteration then expands yet another root
    child, and UCB selection below the root never happens. Caching per epsilon value
    keeps the discretization fixed for a given level.
    """
    cache = {}

    def eps_net_func(epsilon):
        key = float(epsilon)
        net = cache.get(key)
        if net is None:
            net = build_epsilon_net(envname, action_dim, epsilon, lo, hi)
            cache[key] = net
        return net

    return eps_net_func


def _promote_child_to_root(child, snapshot, obs):
    """Turn ``child`` into a new root located at the real state ``snapshot``/``obs``.

    The child's statistics and subtree are reused. The grandchildren are re-parented onto
    the new root. Otherwise their ``parent`` pointers still reference the old child
    object, and backpropagation through the reused subtree would update that stale chain
    instead of the new root.
    NOTE(review): the reused subtree was simulated from the predicted state, which can
    differ from the real state when the env is noisy.
    """
    new_root = DAMCTSRoot(
        snapshot=snapshot,
        obs=obs,
        eps_net_func=child.eps_net_func,
        env_name=child.env_name,
        action_dim=child._action_dim,
        lo=child.lo,
        hi=child.hi,
    )
    new_root.children = child.children
    new_root.visit_count = child.visit_count
    new_root.value_sum = child.value_sum
    new_root.value_sum_power = child.value_sum_power
    for grandchild in new_root.children.values():
        grandchild.parent = new_root
    return new_root


def run_single_env(envname: str, seeds: int, iter_list):
    """Evaluate DAMCTS on one environment for each planning budget in ``iter_list``.

    For every budget ``ITER`` and every seed in ``range(seeds)``:
    1. Seed ``random`` and ``numpy``.
    2. Plan ``ITER`` iterations from a shared initial snapshot.
    3. Run up to ``TEST_ITERATIONS`` steps in a separate test env. Each step takes the
       root child with the highest power-mean value, or a uniform random action if the
       root has no children.
    4. After each step, re-root at the real test state and re-plan ``ITER`` iterations.

    The episode score is the discounted sum of true (unshaped) rewards.

    Returns:
        One summary line per budget. Each line is also printed.
    """
    action_dim, lo, hi = get_env_action_space(envname)
    env_kwargs = ENV_NOISE_CONFIG.get(envname, {})

    # Planning env shared by every seed and budget.
    planning_env = SnapshotEnv(gym.make(envname, disable_env_checker=True, **env_kwargs).env)
    out_lines = []
    try:
        root_obs = planning_env.reset()
        root_snapshot = planning_env.get_snapshot()

        for ITER in iter_list:
            seed_returns = []
            for seed in range(seeds):
                random.seed(seed)
                np.random.seed(seed)

                # Fresh copies of the initial state for this seed.
                s0 = copy.deepcopy(root_obs)
                snap0 = copy.deepcopy(root_snapshot)

                # Per-seed cache, so nets stay deterministic for a given seed.
                eps_net_func = _make_cached_eps_net_func(envname, action_dim, lo, hi)

                root = DAMCTSRoot(
                    snapshot=snap0,
                    obs=s0,
                    eps_net_func=eps_net_func,
                    env_name=envname,
                    action_dim=action_dim,
                    lo=lo,
                    hi=hi,
                )
                plan_damcts(root, planning_env, n_iter=ITER)

                # Evaluate in a separate env instance started from the same snapshot.
                test_env = SnapshotEnv(gym.make(envname, disable_env_checker=True, **env_kwargs).env)
                total_reward = 0.0
                try:
                    test_env.load_snapshot(snap0)
                    disc = 1.0
                    for _step in range(TEST_ITERATIONS):
                        if root.children:
                            best_child = max(root.children.values(), key=lambda c: c.get_power_mean_value())
                            best_action = best_child.action
                        else:
                            best_child = None
                            best_action = np.random.uniform(lo, hi, size=(action_dim,)).astype(np.float32)

                        obs, r, done, _info = test_env.step(best_action)
                        total_reward += r * disc
                        disc *= discount
                        if done:
                            break

                        # Synchronize planning with the real test state.
                        snap_now = test_env.get_snapshot()
                        planning_env.load_snapshot(snap_now)

                        # Re-root at the real state. Reuse the chosen child's subtree if any.
                        if best_child is None:
                            root = DAMCTSRoot(
                                snapshot=snap_now,
                                obs=obs,
                                eps_net_func=eps_net_func,
                                env_name=envname,
                                action_dim=action_dim,
                                lo=lo,
                                hi=hi,
                            )
                        else:
                            root = _promote_child_to_root(best_child, snap_now, obs)

                        plan_damcts(root, planning_env, n_iter=ITER)
                finally:
                    test_env.close()
                seed_returns.append(total_reward)

            mean_ret = statistics.mean(seed_returns)
            std_ret = statistics.pstdev(seed_returns)
            # NOTE: the "±" figure is 2x the population std of per-seed returns (their
            # spread), not a confidence interval for the mean.
            ci = 2.0 * std_ret
            line = f"Env={envname}, ITER={ITER}: Mean={mean_ret:.3f} ± {ci:.3f} (n={seeds})"
            print(line)
            out_lines.append(line)
    finally:
        planning_env.close()
    return out_lines


# --------------------------
# Main
# --------------------------
if __name__ == "__main__":
    # Environments to evaluate ("ImprovedWalker2d-v0" and "ImprovedAnt-v0" are also configured).
    envs = ["ImprovedHumanoid-v0"]

    # Planning budgets: geometric progression int(3 * 1000 ** (i / 15)) for i = 0..15.
    # Only the first six are used (e.g. 3, 4, 7, 11, 18, 29).
    base = 1000 ** (1.0 / 15.0)
    samples = [int(3 * (base ** i)) for i in range(16)]
    iter_list = samples[0:6]

    seeds = 20
    for envname in envs:
        lines = run_single_env(envname, seeds=seeds, iter_list=iter_list)
        # Append this environment's summaries. The encoding is explicitly UTF-8 because
        # the lines contain "±", which the platform default encoding may not represent.
        with open("damcts_results.txt", "a", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
